{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 准备数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 加载模块\n",
    "from datetime import datetime\n",
    "\n",
    "from tqdm import tqdm\n",
    "import rqdatac as rq\n",
    "\n",
    "from vnpy.trader.database import DB_TZ\n",
    "from vnpy.trader.datafeed import get_datafeed\n",
    "from vnpy.trader.constant import Exchange, Interval\n",
    "from vnpy.trader.object import HistoryRequest\n",
    "from vnpy.alpha import AlphaLab, logger"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Download parameters\n",
    "task_name = \"csi300\"    # names the lab folder: ./lab/<task_name>\n",
    "index_symbol = \"000300.SSE\"    # index code used for lab storage and for the bar download\n",
    "rq_index_symbol = \"000300.XSHG\"    # index code in the form passed to rq.index_components\n",
    "\n",
    "# Date range (YYYY-MM-DD) shared by the constituent query and the bar download\n",
    "start_date = \"2007-01-01\"\n",
    "end_date = \"2024-10-31\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the research lab; later cells save and load everything through this object\n",
    "lab = AlphaLab(f\"./lab/{task_name}\")    # data folder for this task"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "c:\\Python313\\Lib\\site-packages\\rqdatac\\client.py:257: UserWarning: Your account will be expired after  279 days. Please call us at 0755-22676337 to upgrade or purchase or renew your contract.\n",
      "  warnings.warn(\"Your account will be expired after  {} days. \"\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Initialize the data service (configured to use RQData here)\n",
    "datafeed = get_datafeed()\n",
    "# The value returned by init() is displayed as the cell output (True in the saved run)\n",
    "datafeed.init()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fetch the index constituents over the whole date range\n",
    "data = rq.index_components(rq_index_symbol, start_date=start_date, end_date=end_date)\n",
    "\n",
    "# Re-key by date string and rewrite the exchange suffixes (XSHG -> SSE, XSHE -> SZSE)\n",
    "index_components = {\n",
    "    dt.strftime(\"%Y-%m-%d\"): [\n",
    "        rq_symbol.replace(\"XSHG\", \"SSE\").replace(\"XSHE\", \"SZSE\")\n",
    "        for rq_symbol in rq_symbols\n",
    "    ]\n",
    "    for dt, rq_symbols in data.items()\n",
    "}\n",
    "\n",
    "# Persist the constituents through the lab\n",
    "lab.save_component_data(index_symbol, index_components)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the constituent symbols of the index for the configured date range\n",
    "component_symbols = lab.load_component_symbols(index_symbol, start_date, end_date)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 859/859 [01:52<00:00,  7.66it/s]\n"
     ]
    }
   ],
   "source": [
    "# Parse the date strings and attach the database timezone\n",
    "start = datetime.strptime(start_date, \"%Y-%m-%d\").replace(tzinfo=DB_TZ)\n",
    "end = datetime.strptime(end_date, \"%Y-%m-%d\").replace(tzinfo=DB_TZ)\n",
    "\n",
    "# Download the index itself in addition to its constituents\n",
    "task_symbols = [*component_symbols, index_symbol]\n",
    "\n",
    "# Request daily bars symbol by symbol\n",
    "for vt_symbol in tqdm(task_symbols):\n",
    "    symbol, exchange_str = vt_symbol.split(\".\")\n",
    "    exchange = Exchange(exchange_str)\n",
    "\n",
    "    req = HistoryRequest(symbol, exchange, start, end, Interval.DAILY)\n",
    "    bars = datafeed.query_bar_history(req)\n",
    "\n",
    "    # Nothing came back: log the failure and move on to the next symbol\n",
    "    if not bars:\n",
    "        logger.error(f\"下载{vt_symbol}数据失败\")\n",
    "        continue\n",
    "\n",
    "    lab.save_bar_data(bars)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Backtest settings shared by every constituent\n",
    "# NOTE(review): long_rate/short_rate look like per-side fee rates (5 and 10 bps) - confirm against AlphaLab\n",
    "contract_setting = {\n",
    "    \"long_rate\": 5 / 10000,\n",
    "    \"short_rate\": 10 / 10000,\n",
    "    \"size\": 1,\n",
    "    \"pricetick\": 0.0001,\n",
    "}\n",
    "\n",
    "# Register the same settings for each constituent (the index itself is not included)\n",
    "for vt_symbol in component_symbols:\n",
    "    lab.add_contract_setting(vt_symbol, **contract_setting)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
